import os
from pathlib import Path
from typing import Dict, Tuple

import matplotlib.pyplot as plt
import pandas as pd


# -------------------------------------------------------------------
# Helpers
# -------------------------------------------------------------------

def infer_prefix_from_dir(log_dir: Path) -> str:
    """Return the part of ``log_dir``'s final component before the first underscore.

    For example ``log/3_loggers`` yields ``"3"``; a name containing no
    underscore is returned unchanged.
    """
    head, _sep, _rest = log_dir.name.partition("_")
    return head


# -------------------------------------------------------------------
# Plotting
# -------------------------------------------------------------------

def plot_rel_rmse(
    rel_rmse: Dict[str, float],
    dataset_title: str,
    outfile: str,
    dpi: int,
) -> None:
    """Save a per-dataset bar chart of relative RMSE values on a log y-axis.

    Bars are drawn in a fixed method order. Methods absent from ``rel_rmse``
    are skipped, and keys that are not one of the known methods are ignored.
    Each method always gets the same colour, whichever other methods are
    present, so charts for different datasets stay visually comparable.

    Args:
        rel_rmse: Mapping from method key (e.g. ``"J_naive_ips"``) to its
            relative RMSE.
        dataset_title: Title drawn above the chart.
        outfile: Destination image path; its parent directory is created if
            it does not exist.
        dpi: Resolution used for both the figure and the saved image.
    """
    import contextlib
    import math

    # (method key, axis label, bar colour) for all 6 supported methods.
    # Colours belong to the method, not to the bar position.
    methods = [
        ("J_naive_ips", "IPS", "tab:red"),
        ("J_balanced_ips", "bIPS", "tab:blue"),
        ("J_weighted_ips", "wIPS", "tab:purple"),
        ("J_dr_balanced_ips", "DR-bIPS", "tab:brown"),
        ("J_optimal_ips", "oIPS", "tab:green"),
        ("J_dr_optimal_ips", "DR-oIPS", "tab:orange"),
    ]
    present = [m for m in methods if m[0] in rel_rmse]
    labels = [label for _key, label, _color in present]
    values = [rel_rmse[key] for key, _label, _color in present]
    colors = [color for _key, _label, color in present]

    # Matplotlib 3.6 renamed the seaborn styles to "seaborn-v0_8-*" and 3.8
    # removed the old names: prefer the new name, fall back to the old one,
    # and use the current rcParams if neither exists.
    available = set(plt.style.available)
    style = next(
        (s for s in ("seaborn-v0_8-whitegrid", "seaborn-whitegrid") if s in available),
        None,
    )
    # Scope the style to this figure instead of mutating global rcParams.
    style_ctx = plt.style.context(style) if style is not None else contextlib.nullcontext()

    with style_ctx:
        fig, ax = plt.subplots(figsize=(7, 4.5), dpi=dpi)
        try:
            ax.set_yscale("log")

            bars = ax.bar(
                labels,
                values,
                color=colors,
                edgecolor="black",
                linewidth=0.8,
            )

            # Annotate each bar with its value. Non-positive or non-finite
            # heights cannot be placed on a log axis, so they are skipped.
            for bar in bars:
                height = bar.get_height()
                if not (math.isfinite(height) and height > 0):
                    continue
                ax.annotate(
                    f"{height:.3f}",
                    (bar.get_x() + bar.get_width() / 2.0, height),
                    xytext=(0, 4),
                    textcoords="offset points",
                    ha="center",
                    va="bottom",
                    fontsize=9,
                )

            ax.set_title(dataset_title, fontweight="bold")
            ax.set_ylabel("Relative RMSE (log scale)")

            positive_values = [v for v in values if math.isfinite(v) and v > 0]
            if positive_values:
                # Multiplicative padding (a fixed distance on the log axis)
                # leaves room for the annotations above the tallest bar.
                ax.set_ylim(min(positive_values) / 1.8, max(positive_values) * 1.8)

            ax.tick_params(axis="x", labelsize=11)
            ax.tick_params(axis="y", labelsize=11)
            ax.set_axisbelow(True)
            ax.yaxis.grid(True, which="both", linestyle="--", linewidth=0.6)

            fig.tight_layout()
            out_dir = os.path.dirname(outfile)
            if out_dir:  # os.makedirs("") raises for a bare filename
                os.makedirs(out_dir, exist_ok=True)
            fig.savefig(outfile, dpi=dpi)
        finally:
            # Release the figure even if drawing or saving fails.
            plt.close(fig)


# -------------------------------------------------------------------
# CSV parsing
# -------------------------------------------------------------------

def parse_filename_metadata(csv_path: Path) -> Tuple[str, bool]:
    """Read the dataset name and ``estimate_pi_b`` flag encoded in a results filename.

    Names are expected in the form written by ``utils.save_rel_rmses``, e.g.::

        rel_rmse__dataset=pendigits__estimate_pi_b=false__num_loggers=3__n_fold=5__....

    That is, a ``rel_rmse__`` prefix followed by ``key=value`` fields joined
    with ``__``, plus an optional ``.csv`` extension. Unrecognised fields are
    ignored, and a repeated field keeps its last value. A name without the
    prefix, or a missing field, yields the defaults ``"unknown"`` / ``False``.
    The flag is true only when its value equals ``"true"`` case-insensitively.
    """
    marker = "rel_rmse__"
    filename = csv_path.name
    if not filename.startswith(marker):
        return "unknown", False

    stem = filename[len(marker):]
    if stem.endswith(".csv"):
        stem = stem[: -len(".csv")]

    # Collect key=value pairs; fields without "=" carry no metadata.
    fields: Dict[str, str] = {}
    for field in stem.split("__"):
        key, sep, value = field.partition("=")
        if sep:
            fields[key] = value

    dataset_name = fields.get("dataset", "unknown")
    pi_b_estimated = fields.get("estimate_pi_b", "").lower() == "true"
    return dataset_name, pi_b_estimated


def load_rel_rmse_dict(csv_path: Path) -> Dict[str, float]:
    """Load a results CSV into a ``{method: rel_rmse}`` mapping.

    The CSV must contain ``method`` and ``rel_rmse`` columns; other columns
    are ignored. Method names are converted with ``str`` and values with
    ``float``. If a method appears on several rows, the last row wins.

    Raises:
        KeyError: if either required column is missing.
        ValueError: if the file cannot be parsed or a value is not numeric
            (pandas' empty-file and parser errors subclass ``ValueError``).
    """
    df = pd.read_csv(csv_path)
    missing = [col for col in ("method", "rel_rmse") if col not in df.columns]
    if missing:
        raise KeyError(f"{csv_path}: missing required column(s) {missing}")
    # Zip the columns directly instead of using DataFrame.iterrows(). iterrows
    # packs each row into a Series, which upcasts mixed dtypes: an integer
    # method id 1 would become 1.0 and be stored as "1.0". It is also slow.
    return {
        str(method): float(value)
        for method, value in zip(df["method"], df["rel_rmse"])
    }


# -------------------------------------------------------------------
# Main
# -------------------------------------------------------------------

def main(args) -> None:
    """Render one relative-RMSE bar chart per CSV file found in ``args.log_dir``.

    ``args`` must provide ``log_dir``, ``file_glob``, ``prefix``,
    ``image_format`` and ``dpi``, as set by this module's argparse CLI.

    Each figure is written into the log directory as
    ``<prefix>-<dataset>-<known|estimated>.<format>``. That name omits other
    filename fields such as ``num_loggers``/``n_fold``, so two CSVs can map to
    the same name. When that happens, the later figure is named after its CSV
    stem instead of silently overwriting the earlier one. CSVs that cannot be
    read are reported on stderr and skipped.
    """
    import sys

    input_dir = Path(args.log_dir)

    if not input_dir.is_dir():
        print(f"Directory not found: {input_dir}", file=sys.stderr)
        return

    csv_files = sorted(input_dir.glob(args.file_glob))
    if not csv_files:
        print(
            f"No CSV files found in {input_dir} matching pattern {args.file_glob!r}",
            file=sys.stderr,
        )
        return

    # A missing (None) or empty --prefix falls back to the directory-derived one.
    file_prefix = args.prefix or infer_prefix_from_dir(input_dir)

    title_map = {
        "letter": "Letter",
        "optdigits": "OptDigits",
        "pendigits": "PenDigits",
        "sat": "SatImage",
    }
    written = set()

    for csv_path in csv_files:
        dataset, estimate_pi_b = parse_filename_metadata(csv_path)
        try:
            rel_rmse = load_rel_rmse_dict(csv_path)
        except (KeyError, ValueError) as err:
            # Missing columns or a malformed/empty CSV: skip this file but
            # keep processing the rest of the directory.
            print(f"Skipping {csv_path}: {err}", file=sys.stderr)
            continue

        dataset_title = title_map.get(dataset, dataset.title())
        suffix = "estimated" if estimate_pi_b else "known"
        outfile = input_dir / f"{file_prefix}-{dataset}-{suffix}.{args.image_format}"
        if outfile in written:
            # Disambiguate rather than overwrite a figure written earlier in this run.
            print(
                f"Warning: {outfile.name} already written in this run; "
                f"naming figure after {csv_path.name} instead.",
                file=sys.stderr,
            )
            outfile = input_dir / f"{file_prefix}-{csv_path.stem}.{args.image_format}"
        written.add(outfile)

        plot_rel_rmse(rel_rmse, dataset_title, str(outfile), dpi=args.dpi)
        print(f"Saved figure: {outfile}")


if __name__ == "__main__":
    import argparse

    # Command-line entry point: each option becomes an attribute that main() reads.
    cli = argparse.ArgumentParser(
        description="Plot per-dataset Relative RMSE bar charts for multi-logger experiments."
    )
    cli.add_argument(
        "--log_dir", type=str, default="log/3_loggers",
        help="Directory containing rel_rmse__*.csv files.",
    )
    cli.add_argument(
        "--file_glob", type=str, default="rel_rmse__*.csv",
        help="Glob pattern of CSV files to read within log_dir.",
    )
    cli.add_argument(
        "--prefix", type=str, default=None,
        help="Filename prefix, e.g. '3'. If omitted, inferred from log_dir like '3_loggers' -> '3'.",
    )
    cli.add_argument(
        "--format", dest="image_format", type=str, default="png",
        choices=["png", "pdf", "svg"], help="Output image format.",
    )
    cli.add_argument(
        "--dpi", type=int, default=200, help="Figure DPI for saved images.",
    )

    main(cli.parse_args())
